package com.simafei.flow.core.data.impl;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.simafei.flow.core.common.Condition;
import com.simafei.flow.core.common.FieldScope;
import com.simafei.flow.core.common.FieldType;
import com.simafei.flow.core.common.Variable;
import com.simafei.flow.core.data.AggColumn;
import com.simafei.flow.core.data.AggSpec;
import com.simafei.flow.core.data.DataManager;
import com.simafei.flow.core.data.LoadSpec;
import com.simafei.flow.core.data.StoreSpec;
import com.simafei.flow.core.data.ds.DataSourceExecutor;
import com.simafei.flow.core.data.ds.FlowDataSource;
import org.apache.ibatis.jdbc.SQL;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.support.rowset.SqlRowSet;

import javax.sql.DataSource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * {@link DataManager} that reads and writes tables through a {@link JdbcTemplate} over a
 * {@link FlowDataSource}.
 *
 * <p>Every operation is wrapped in {@link DataSourceExecutor#call} together with the name of the
 * data source taken from the spec (or argument). Presumably this selects which registered data
 * source the shared template talks to — TODO confirm against DataSourceExecutor.
 *
 * <p>The schema queries ({@code SHOW FULL TABLES} and {@code INFORMATION_SCHEMA.COLUMNS} with a
 * {@code column_comment} column) assume a MySQL-compatible database.
 *
 * @author fengpengju
 */
public class SqlDataManager implements DataManager {

    /** Maximum number of rows returned by a single {@link #load} call. */
    private static final int BATCH_SIZE = 1000;

    private static final String QUERY_SCHEMA_SQL = "select column_name,column_comment,data_type from INFORMATION_SCHEMA.COLUMNS " +
            "where table_schema=? and table_name=?";

    /**
     * Lower-cased INFORMATION_SCHEMA {@code data_type} values treated as {@link FieldType#INTEGER}.
     * An explicit set is used because a substring test for "int" also matches "point"/"multipoint".
     */
    private static final Set<String> INTEGER_TYPES =
            Set.of("tinyint", "smallint", "mediumint", "int", "integer", "bigint");

    private final JdbcTemplate jdbcTemplate;

    private final FlowDataSource dataSource;

    /**
     * Creates a manager that runs all SQL through a {@link JdbcTemplate} built on {@code dataSource}.
     *
     * @param dataSource the data source used for every query; {@link #add}, {@link #remove},
     *                   {@link #exists} and {@link #getDataSourceNames} delegate to it
     */
    public SqlDataManager(FlowDataSource dataSource) {
        this.dataSource = dataSource;
        this.jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Loads rows from {@code config}'s table that satisfy its criteria.
     *
     * @param config table name, optional column list (all columns when empty), criteria and
     *               data-source name
     * @param params values used to render the criteria into a parameterized WHERE clause
     * @return at most {@link #BATCH_SIZE} matching rows, each as a column-name to value map
     */
    @Override
    public List<Map<String, Object>> load(LoadSpec config, Map<String, Object> params) {
        return DataSourceExecutor.call(() -> {
            String select = "*";
            if (CollectionUtil.isNotEmpty(config.getColumns())) {
                select = StrUtil.join(",", config.getColumns());
            }
            Condition.SqlConditionResult result = config.getCriteria().toSqlExpr(params);
            String sql = new SQL()
                    .SELECT(select)
                    .FROM(config.getTableName())
                    .WHERE(result.getWhen())
                    // cap the result: at most BATCH_SIZE rows per load call
                    .LIMIT(BATCH_SIZE)
                    .toString();

            return jdbcTemplate.queryForList(sql, result.getArgs().toArray());
        }, config.getDataSource());
    }

    /**
     * Runs an aggregate query ({@code func(column) as alias} per aggregate column) over
     * {@code query}'s table, optionally grouped.
     *
     * @param query  table, group-by columns, aggregate columns, criteria and data-source name
     * @param params values used to render the criteria into a parameterized WHERE clause
     * @return one row per group (or a single row when ungrouped); an empty list when no aggregate
     *         columns are configured
     */
    @Override
    public List<Map<String, Object>> aggregate(AggSpec query, Map<String, Object> params) {
        return DataSourceExecutor.call(() -> {
            // without aggregate columns there is nothing to compute
            if (CollectionUtil.isEmpty(query.getAggColumns())) {
                return List.of();
            }
            boolean grouped = CollectionUtil.isNotEmpty(query.getGroupBy());
            List<String> columns = new ArrayList<>();
            if (grouped) {
                // select the group-by keys too, so each result row identifies its group
                columns.addAll(query.getGroupBy());
            }
            for (AggColumn column : query.getAggColumns()) {
                // an aggregate without a column (e.g. count) applies to "*"
                String col = StrUtil.isEmpty(column.getColumn()) ? "*" : column.getColumn();
                columns.add(String.format("%s(%s) as %s", column.getAggFunc().getFunc(), col, column.getAlias()));
            }
            Condition.SqlConditionResult result = query.getCriteria().toSqlExpr(params);
            SQL sql = new SQL()
                    .SELECT(StrUtil.join(",", columns))
                    .FROM(query.getTableName())
                    .WHERE(result.getWhen());
            if (grouped) {
                sql.GROUP_BY(StrUtil.join(",", query.getGroupBy()));
            }

            return jdbcTemplate.queryForList(sql.toString(), result.getArgs().toArray());
        }, query.getDataSource());
    }

    /**
     * Upserts each row of {@code params} into {@code storeSpec}'s table.
     *
     * <p>For every row, the spec's criteria are rendered against that row. If a matching record
     * exists, it is updated; otherwise a new record is inserted. The mapping's keys are the table
     * columns written, and its values are the keys looked up in each row map (a missing key is
     * written as SQL NULL).
     *
     * @param storeSpec table name, column mapping, criteria and data-source name
     * @param params    rows to store
     * @return the affected-row count of each statement, in the order of {@code params}; empty when
     *         the mapping or {@code params} is empty
     */
    @Override
    public List<Integer> saveOrUpdate(StoreSpec storeSpec, List<Map<String, Object>> params) {
        return DataSourceExecutor.call(() -> {
            List<Integer> resList = new ArrayList<>();

            Map<String, String> mapping = storeSpec.getMapping();
            if (CollectionUtil.isEmpty(mapping) || CollectionUtil.isEmpty(params)) {
                return resList;
            }

            String[] columns = mapping.keySet().toArray(String[]::new);
            // the INSERT statement depends only on the columns, so build it once
            String insertSql = getInsertSql(storeSpec, columns);
            for (Map<String, Object> param : params) {
                // Values are bound through '?' placeholders, so pass the raw objects: the driver
                // handles quoting, and a missing value becomes SQL NULL instead of the text "null".
                List<Object> values = new ArrayList<>(columns.length);
                for (String column : columns) {
                    values.add(param.get(mapping.get(column)));
                }

                Condition.SqlConditionResult result = storeSpec.getCriteria().toSqlExpr(param);
                // look for an existing record first: update it if found, otherwise insert
                String selectSql = getSelectSql(storeSpec, result);
                List<Map<String, Object>> existing =
                        jdbcTemplate.queryForList(selectSql, result.getArgs().toArray());

                int res;
                if (CollectionUtil.isNotEmpty(existing)) {
                    String updateSql = getUpdateSql(storeSpec, columns, result);
                    // SET placeholders come first in the statement, then the WHERE placeholders
                    List<Object> args = new ArrayList<>(values);
                    args.addAll(result.getArgs());
                    res = jdbcTemplate.update(updateSql, args.toArray());
                } else {
                    res = jdbcTemplate.update(insertSql, values.toArray());
                }
                resList.add(res);
            }
            return resList;
        }, storeSpec.getDataSource());
    }

    /** Returns the names reported by the underlying {@link FlowDataSource}. */
    @Override
    public List<String> getDataSourceNames() {
        return dataSource.getDataSourceNames();
    }

    /** Registers {@code dataSource} under {@code key} in the underlying {@link FlowDataSource}. */
    @Override
    public void add(String key, DataSource dataSource) {
        this.dataSource.add(key, dataSource);
    }

    /** Removes the data source registered under {@code key} from the underlying {@link FlowDataSource}. */
    @Override
    public void remove(String key) {
        this.dataSource.remove(key);
    }

    /** Returns whether the underlying {@link FlowDataSource} reports a data source under {@code key}. */
    @Override
    public boolean exists(String key) {
        return this.dataSource.exists(key);
    }

    /**
     * Lists table names via {@code SHOW FULL TABLES} (MySQL syntax).
     *
     * @param dataSource name of the data source handed to {@link DataSourceExecutor#call}
     * @return the first column of each result row
     */
    @Override
    public List<String> tables(String dataSource) {
        return DataSourceExecutor.call(() -> jdbcTemplate.query("SHOW FULL TABLES",
                (resultSet, i) -> resultSet.getString(1)), dataSource);
    }

    /**
     * Describes the columns of a table as {@link Variable}s with {@link FieldScope#REQUEST} scope.
     *
     * @param dataSource name of the data source handed to {@link DataSourceExecutor#call}
     * @param database   schema name matched against {@code table_schema}
     * @param tableName  table name matched against {@code table_name}
     * @return one variable per column; columns whose {@code data_type} is null are skipped
     */
    @Override
    public List<Variable> columns(String dataSource, String database, String tableName) {
        return DataSourceExecutor.call(() -> {
            SqlRowSet resultSet = jdbcTemplate.queryForRowSet(QUERY_SCHEMA_SQL, database, tableName);
            List<Variable> columns = new ArrayList<>();
            while (resultSet.next()) {
                String dataType = resultSet.getString("data_type");
                if (dataType == null) {
                    continue;
                }
                Variable variable = new Variable();
                variable.setVarName(resultSet.getString("column_name"));
                variable.setVarLabel(resultSet.getString("column_comment"));
                variable.setScope(FieldScope.REQUEST);
                variable.setVarType(toFieldType(dataType));
                columns.add(variable);
            }
            return columns;
        }, dataSource);
    }

    /**
     * Maps an INFORMATION_SCHEMA {@code data_type} value to a {@link FieldType}.
     *
     * @param dataType non-null SQL type name, matched case-insensitively
     * @return INTEGER, DECIMAL, DATE_TIME or DATE for recognized types; STRING otherwise
     */
    private static FieldType toFieldType(String dataType) {
        String type = dataType.toLowerCase(Locale.ROOT);
        if (INTEGER_TYPES.contains(type)) {
            return FieldType.INTEGER;
        }
        if (type.contains("varchar")) {
            return FieldType.STRING;
        }
        if (type.contains("decimal")) {
            return FieldType.DECIMAL;
        }
        if (type.contains("datetime") || type.contains("timestamp")) {
            return FieldType.DATE_TIME;
        }
        if ("date".equals(type)) {
            return FieldType.DATE;
        }
        return FieldType.STRING;
    }

    /** Builds {@code UPDATE table SET c1=?,... WHERE (criteria)}; SET placeholders precede WHERE ones. */
    private static String getUpdateSql(StoreSpec spec, String[] columns, Condition.SqlConditionResult result) {
        return new SQL()
                .UPDATE(spec.getTableName())
                .SET(getSetSql(columns))
                .WHERE(result.getWhen())
                .toString();
    }

    /** Builds an existence probe: {@code SELECT * FROM table WHERE (criteria) LIMIT 1}. */
    private static String getSelectSql(StoreSpec spec, Condition.SqlConditionResult result) {
        return new SQL()
                .SELECT("*")
                .FROM(spec.getTableName())
                .WHERE(result.getWhen())
                .LIMIT(1)
                .toString();
    }

    /** Builds {@code INSERT INTO table (c1,...) VALUES (?,...)} with one placeholder per column. */
    private static String getInsertSql(StoreSpec spec, String[] columns) {
        String[] values = new String[columns.length];
        Arrays.fill(values, "?");
        return new SQL()
                .INSERT_INTO(spec.getTableName())
                .INTO_COLUMNS(columns)
                .INTO_VALUES(values)
                .toString();
    }

    /** Turns each column name into a {@code column=?} SET fragment, preserving order. */
    private static String[] getSetSql(String[] columns) {
        return Arrays.stream(columns)
                .map(column -> column + "=?")
                .toArray(String[]::new);
    }
}
